# --------------------------------------------------------
# Licensed under The MIT License
# Based on timm and DeiT code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit/
# --------------------------------------------------------
import math
from operator import mod
from numpy.core.fromnumeric import std
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import torch.utils.checkpoint as checkpoint

from functools import partial

from modeling_finetune import Block, LayerNormWithForceFP32, _cfg, PatchEmbed, RelativePositionBias
from timm.models.registry import model_entrypoint, register_model
from timm.models.layers import trunc_normal_ as __call_trunc_normal_

from dall_e.utils import unmap_pixels
from torchvision import models as torchvision_models

from einops import rearrange

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from modeling_pretrain_cnn import CNNForMaskedImageModeling

def trunc_normal_(tensor, mean=0., std=1.):
    """Fill ``tensor`` in place from a truncated normal distribution.

    Delegates to timm's ``trunc_normal_`` with the absolute truncation
    bounds fixed to ``[-std, std]`` (independent of ``mean``).
    """
    bound = std
    __call_trunc_normal_(tensor, mean=mean, std=std, a=-bound, b=bound)


# Public names of this module: the model factories defined in this file.
# (Listing names that do not exist here makes ``from <module> import *`` fail.)
__all__ = [
    'beit_e_base_patch16_224_voc8k',
    'beit_e_small_patch16_224_voc8k',
    'beit_e_base_patch16_224_voc8k_debug',
    'beit_e_large_patch16_224_8k_vocab',
    'beit_e_resnet50_patch16_224_voc8k',
]


def log(t, eps=1e-9):
    """Return the natural log of ``t + eps``, so zero inputs give a finite result."""
    shifted = t + eps
    return torch.log(shifted)


def gumbel_noise(t):
    """Draw Gumbel(0, 1) noise with the same shape, dtype and device as ``t``.

    Uses ``-log(-log(u))`` with ``u ~ U[0, 1)``; ``log`` adds a small epsilon
    so that ``u == 0`` does not produce an infinity.
    """
    uniform = torch.empty_like(t).uniform_(0, 1)
    inner = -log(uniform)
    return -log(inner)


def gumbel_sample(t, temperature=1.):
    """Sample an index along the last axis of the logits ``t`` (Gumbel-max trick).

    The logits are cast to float32, scaled by ``1 / temperature``, perturbed with
    Gumbel noise, and the argmax over the last axis is returned.
    """
    logits = t.float()
    perturbed = logits / temperature + gumbel_noise(logits)
    return perturbed.argmax(dim=-1)


class VisionTransformerForMaskedImageModeling(nn.Module):
    """Generator/discriminator ViT pair for replaced-token-detection image pre-training.

    The generator (``patch_embed_gen``, ``blocks_gen``, ``norm_gen``, ``gen_head``; width
    ``embed_dim // gen_dim_ratio``) predicts the discrete tokenizer ids of masked patches.
    ``forward`` samples from those predictions, writes the samples into the tokenizer's id
    map, decodes the ids back to an image with ``d_vae`` and feeds that image to the
    discriminator (``patch_embed``, ``blocks``, ``norm``, ``dis_head``). The discriminator
    is trained either to classify which token positions were replaced (binary
    cross-entropy) or, with ``reg_diff``, to regress a per-pixel target.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, discrete_vae_type="dall-e", vocab_size=8192, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=None, init_values=None, attn_head_dim=None,
                 use_abs_pos_emb=False, use_rel_pos_bias=False, use_shared_rel_pos_bias=True, init_std=0.02,
                 gen_dim_ratio=2, gen_depth=4, disable_gen_amp=False, disable_interpolate=False, share_num_layers=0, 
                 random_replace=False, tokenizer_stride=8, keep_ratio=0., dis_loss_weight=50, gumbel_sample_temperature=1., 
                 reg_diff=False, both_cls_reg=False, non_overlapping_win_pixel_norm=False, recon_ori=False, recon_unnorm_ori=False, vis_gen=False, share_layers_sg=False):
        """Build the generator and the discriminator.

        Selected arguments (the others follow the usual ViT/BEiT meaning):
            norm_layer: callable building a norm layer; required (it is called
                unconditionally for ``norm_gen`` and ``norm``).
            discrete_vae_type: tokenizer family handled in ``forward``: "dall-e",
                "zhiliang", "peco" or "vqgan".
            gen_dim_ratio: generator width and head count are divided by this.
            gen_depth: number of generator blocks.
            disable_interpolate: predict a 2x2 grid of sub-tokens per patch
                (``4 * vocab_size`` generator logits, 4 discriminator logits).
            share_num_layers: only used when ``gen_dim_ratio == 1``. A value >= 1 shares the
                patch embedding (and absolute position embedding) with the generator. A value
                k > 1 additionally shares the first k - 1 transformer blocks.
            share_layers_sg: run the shared modules without gradient in the discriminator.
            reg_diff: replace the discriminator head by a per-pixel regression head with
                ``patch_size * patch_size * 3`` outputs per patch.
            both_cls_reg: add a separate 4-output classification head (``dis_head_cls``).
            dis_loss_weight: factor applied to the discriminator loss returned by ``forward``.
        """
        super().__init__()

        self.disable_gen_amp = disable_gen_amp
        self.disable_interpolate = disable_interpolate
        self.discrete_vae_type = discrete_vae_type
        self.vocab_size = vocab_size
        self.random_replace = random_replace
        self.tokenizer_stride = tokenizer_stride
        self.keep_ratio = keep_ratio
        self.dis_loss_weight = dis_loss_weight
        self.reg_diff = reg_diff
        self.patch_size = patch_size
        self.both_cls_reg = both_cls_reg
        self.non_overlapping_win_pixel_norm = non_overlapping_win_pixel_norm
        self.recon_ori = recon_ori
        self.recon_unnorm_ori = recon_unnorm_ori
        self.vis_gen = vis_gen
        self.gumbel_sample_temperature = gumbel_sample_temperature
        self.share_layers_sg = share_layers_sg
        self.share_num_layers = share_num_layers

        # num_features for consistency with other models
        self.num_features = self.embed_dim = embed_dim

        # The generator is narrower than the discriminator by a factor of gen_dim_ratio.
        embed_dim_gen = embed_dim // gen_dim_ratio

        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        self.patch_embed_gen = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim_gen)

        num_patches_gen = self.patch_embed_gen.num_patches
        num_patches = self.patch_embed.num_patches

        self.cls_token_gen = nn.Parameter(torch.zeros(1, 1, embed_dim_gen))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Only the generator sees masked inputs, so the mask token has generator width.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim_gen))

        if use_abs_pos_emb:
            self.pos_embed_gen = nn.Parameter(
                torch.zeros(1, num_patches_gen + 1, embed_dim_gen))
            self.pos_embed = nn.Parameter(
                torch.zeros(1, num_patches + 1, embed_dim))
        else:
            self.pos_embed = None
            self.pos_embed_gen = None

        self.pos_drop = nn.Dropout(p=drop_rate)

        if use_shared_rel_pos_bias:
            self.rel_pos_bias_gen = RelativePositionBias(
                window_size=self.patch_embed_gen.patch_shape, num_heads=num_heads // gen_dim_ratio)
            self.rel_pos_bias = RelativePositionBias(
                window_size=self.patch_embed.patch_shape, num_heads=num_heads)
        else:
            self.rel_pos_bias_gen = None
            self.rel_pos_bias = None

        # NOTE: the upper end 0. / (depth / gen_depth) is always 0, so every generator block
        # gets drop_path=0 regardless of drop_path_rate.
        dpr_gen = [x.item() for x in torch.linspace(0, 0. / (depth / gen_depth), gen_depth)]  # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        self.blocks_gen = nn.ModuleList([
            Block(
                dim=embed_dim_gen, num_heads=num_heads // gen_dim_ratio, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=0., attn_drop=0., drop_path=dpr_gen[i], norm_layer=norm_layer,
                init_values=init_values, window_size=self.patch_embed_gen.patch_shape if use_rel_pos_bias else None,
                attn_head_dim=attn_head_dim,
            )
            for i in range(gen_depth)])

        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
                attn_head_dim=attn_head_dim,
            )
            for i in range(depth)])

        self.norm_gen = norm_layer(embed_dim_gen)
        self.norm = norm_layer(embed_dim)

        self.init_std = init_std

        if self.disable_interpolate:
            # One prediction per 2x2 sub-token of every patch.
            self.gen_head = nn.Linear(embed_dim_gen, 4 * vocab_size)
            self.dis_head = nn.Linear(embed_dim, 4)
        else:
            self.gen_head = nn.Linear(embed_dim_gen, vocab_size)
            self.dis_head = nn.Linear(embed_dim, 1)

        if self.reg_diff:
            # Per-pixel regression head: patch_size * patch_size * 3 values per patch.
            self.dis_head = nn.Linear(embed_dim, patch_size * patch_size * 3)

        if self.both_cls_reg:
            self.dis_head_cls = nn.Linear(embed_dim, 4)
            trunc_normal_(self.dis_head_cls.weight, std=self.init_std)

        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed_gen, std=self.init_std)
            trunc_normal_(self.pos_embed, std=self.init_std)

        trunc_normal_(self.cls_token_gen, std=self.init_std)
        trunc_normal_(self.cls_token, std=self.init_std)
        trunc_normal_(self.mask_token, std=self.init_std)

        trunc_normal_(self.gen_head.weight, std=self.init_std)
        trunc_normal_(self.dis_head.weight, std=self.init_std)

        self.apply(self._init_weights)
        self.fix_init_weight()

        # Sharing is only performed when generator and discriminator have equal width.
        # share_num_layers counts the patch embedding as the first shared layer.
        if gen_dim_ratio == 1:
            if share_num_layers > 0:
                self.patch_embed_gen = self.patch_embed
                self.pos_embed_gen = self.pos_embed
            if share_num_layers > 1:
                for i in range(share_num_layers - 1):
                    self.blocks_gen[i] = self.blocks[i]

    def fix_init_weight(self):
        """Divide each block's attention-projection and fc2 weights by sqrt(2 * layer_id).

        ``layer_id`` is 1-based and counted separately for generator and discriminator.
        """
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks_gen):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        """Initialize one submodule; applied recursively through ``self.apply``.

        Linear/Conv2d weights get a truncated normal with ``init_std`` and zero bias;
        LayerNorm gets weight 1 and bias 0.
        """
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # Guard against LayerNorms built with elementwise_affine=False.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Names of parameters that should be excluded from weight decay."""
        return {'pos_embed', 'cls_token_gen', 'cls_token', 'mask_token', 'pos_embed_gen'}

    def get_num_layers(self):
        """Number of discriminator transformer blocks."""
        return len(self.blocks)

    def forward_features(self, x, bool_masked_pos):
        """Run the generator encoder, replacing masked patch embeddings by ``mask_token``.

        ``bool_masked_pos`` is expected flattened to one entry per patch (see ``forward``).
        Returns the normalized generator features, including the leading CLS position.
        """
        x = self.patch_embed_gen(x, bool_masked_pos=bool_masked_pos)
        batch_size, seq_len, _ = x.size()

        # stole cls_tokens impl from Phil Wang, thanks
        cls_tokens = self.cls_token_gen.expand(batch_size, -1, -1)
        mask_token = self.mask_token.expand(batch_size, seq_len, -1)

        # replace the masked visual tokens by mask_token (w == 1 where masked)
        w = bool_masked_pos.unsqueeze(-1).type_as(mask_token)
        x = x * (1 - w) + mask_token * w

        x = torch.cat((cls_tokens, x), dim=1)

        if self.pos_embed_gen is not None:
            x = x + self.pos_embed_gen
        x = self.pos_drop(x)

        rel_pos_bias = self.rel_pos_bias_gen() if self.rel_pos_bias_gen is not None else None
        for blk in self.blocks_gen:
            # x = checkpoint.checkpoint(blk, x, rel_pos_bias)    # saves mem, takes time
            x = blk(x, rel_pos_bias=rel_pos_bias)

        return self.norm_gen(x)

    def forward_features_dis(self, x):
        """Run the discriminator on (normalized) images.

        Returns ``(dis_head(tokens), dis_head_cls(tokens) or None)`` where ``tokens`` are
        the normalized features without the CLS position.
        """
        # With share_layers_sg, modules shared with the generator run without gradient
        # here, so they are updated only through the generator loss.
        if self.share_layers_sg and self.share_num_layers > 0:
            with torch.no_grad():
                x = self.patch_embed(x)
        else:
            x = self.patch_embed(x)

        batch_size, _, _ = x.size()
        # stole cls_tokens impl from Phil Wang, thanks
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)

        x = torch.cat((cls_tokens, x), dim=1)

        if self.pos_embed is not None:
            x = x + self.pos_embed
        x = self.pos_drop(x)

        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None

        # share_num_layers counts the patch embedding as the first shared layer, so only
        # blocks[:share_num_layers - 1] are shared with the generator (see __init__). Stopping
        # the gradient on one more block would leave a discriminator-only block untrained.
        num_shared_blocks = self.share_num_layers - 1
        if self.share_layers_sg and num_shared_blocks > 0:
            with torch.no_grad():
                for blk in self.blocks[:num_shared_blocks]:
                    x = blk(x, rel_pos_bias=rel_pos_bias)
            for blk in self.blocks[num_shared_blocks:]:
                x = blk(x, rel_pos_bias=rel_pos_bias)
        else:
            for blk in self.blocks:
                x = blk(x, rel_pos_bias=rel_pos_bias)

        x = self.norm(x)

        if self.both_cls_reg:
            return self.dis_head(x[:, 1:]), self.dis_head_cls(x[:, 1:])
        else:
            return self.dis_head(x[:, 1:]), None     # B, N, out_dim

    @staticmethod
    def _patch_logits_to_tokens(logits, grid_shape):
        """Flatten per-patch sub-token logits into a raster-ordered token sequence.

        ``logits`` has shape (B, grid_h * grid_w, s * s): every patch of the
        grid_h x grid_w grid carries one logit per cell of an s x s sub-grid. The result
        has shape (B, grid_h * s * grid_w * s), ordered row by row over the
        (grid_h * s) x (grid_w * s) token grid, i.e. einops
        ``'b (h w) (p1 p2) -> b (h p1 w p2)'`` with ``p1 = p2 = s`` (s = 1 is a plain squeeze).

        Raises:
            ValueError: if the last dimension is not a perfect square or the patch count
                does not match ``grid_shape``.
        """
        grid_h, grid_w = grid_shape
        batch_size, num_patches, sub_tokens = logits.shape
        side = int(round(math.sqrt(sub_tokens)))
        if side * side != sub_tokens or num_patches != grid_h * grid_w:
            raise ValueError(
                f"cannot map logits of shape {tuple(logits.shape)} onto a "
                f"{grid_h}x{grid_w} patch grid")
        out = logits.reshape(batch_size, grid_h, grid_w, side, side)
        out = out.permute(0, 1, 3, 2, 4)    # B, grid_h, s, grid_w, s
        return out.reshape(batch_size, -1)

    @staticmethod
    def _sliding_window_unfold(images, patch_size):
        """Collect a (patch_size // 2) x (patch_size // 2) window around every pixel.

        ``images`` (B, C, H, W) is reflection-padded by ``patch_size // 4 - 1`` on the
        left/top and ``patch_size // 4`` on the right/bottom (3 and 4 for patch_size 16),
        then unfolded with stride 1, which keeps one window per pixel when
        ``patch_size`` is a multiple of 4.

        Returns:
            Tensor of shape (B, C, (patch_size // 2) ** 2, H * W).
        """
        win = patch_size // 2
        pad_lo, pad_hi = patch_size // 4 - 1, patch_size // 4
        padded = nn.ReflectionPad2d((pad_lo, pad_hi, pad_lo, pad_hi))(images)
        cols = nn.Unfold(kernel_size=(win, win))(padded)    # B, C * win * win, H * W
        batch_size, channels = images.shape[:2]
        return cols.reshape(batch_size, channels, win * win, -1)

    def _sliding_window_normalize(self, images):
        """Standardize each pixel with its sliding-window mean/std, then patchify.

        Returns a tensor of shape (B, num_patches, patch_size * patch_size * C).
        """
        windows = self._sliding_window_unfold(images, self.patch_size)
        normed = (images.flatten(-2) - windows.mean(dim=-2)
            ) / (windows.var(dim=-2, unbiased=True).sqrt() + 1e-6)
        normed = normed.view_as(images)
        return rearrange(normed, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=self.patch_size, p2=self.patch_size)

    def _non_overlapping_normalize(self, images):
        """Standardize each patch (per channel) with its own mean/std, then patchify.

        Returns a tensor of shape (B, num_patches, patch_size * patch_size * C).
        """
        patches = rearrange(images, 'b c (h p1) (w p2) -> b (h w) (p1 p2) c', p1=self.patch_size, p2=self.patch_size)
        normed = (patches - patches.mean(dim=-2, keepdim=True)
            ) / (patches.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6)
        return rearrange(normed, 'b n p c -> b n (p c)')

    def _regression_target(self, vit_input, recon_img, mean, std):
        """Build the per-pixel regression target for the ``reg_diff`` discriminator.

        - ``recon_unnorm_ori``: the patchified ``vit_input`` itself.
        - ``recon_ori``: the locally normalized original image.
        - otherwise: normalized reconstruction minus normalized original ("residual").
        """
        if self.recon_unnorm_ori:
            return rearrange(vit_input, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=self.patch_size, p2=self.patch_size)

        if self.non_overlapping_win_pixel_norm:
            normalize = self._non_overlapping_normalize
        else:
            normalize = self._sliding_window_normalize

        with torch.no_grad():
            unnorm_ori_images = vit_input * std + mean  # in [0, 1]
            per_patch_norm_ori_images = normalize(unnorm_ori_images)
            if self.recon_ori:
                return per_patch_norm_ori_images
            return normalize(recon_img) - per_patch_norm_ori_images

    def _visualize_regression(self, dis_output, vit_input, mean, std):
        """Undo the target normalization on the discriminator's regression output.

        Only used for visualization (``vis_gen``); returns an image clamped to [0, 1].
        """
        grid_h, grid_w = self.patch_embed.patch_shape
        p = self.patch_size

        if self.recon_unnorm_ori:
            image = rearrange(dis_output, 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)', h=grid_h, w=grid_w, p1=p, p2=p)
            return (image * std + mean).clamp(0, 1)  # in [0, 1]

        unnorm_ori_images = vit_input * std + mean  # in [0, 1]
        if self.non_overlapping_win_pixel_norm:
            gen_norm = rearrange(dis_output, "b n (p c) -> b n p c", p=p * p)
            patches = rearrange(unnorm_ori_images, 'b c (h p1) (w p2) -> b (h w) (p1 p2) c', p1=p, p2=p)
            image = gen_norm * (patches.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6) + patches.mean(dim=-2, keepdim=True)
            return rearrange(image, 'b (h w) (p1 p2) c -> b c (h p1) (w p2)', h=grid_h, w=grid_w, p1=p, p2=p).clamp(0, 1)

        gen_norm = rearrange(dis_output, "b (h w) (p1 p2 c) -> b c (h p1) (w p2)", h=grid_h, w=grid_w, p1=p, p2=p).flatten(-2)
        windows = self._sliding_window_unfold(unnorm_ori_images, p)
        image = gen_norm * (windows.var(dim=-2, unbiased=True).sqrt() + 1e-6) + windows.mean(dim=-2)
        return image.view_as(vit_input).clamp(0, 1)

    @staticmethod
    def _replaced_token_metrics(logits, dis_label, masked_pos):
        """Accuracies of the hard replaced/original decision (logit > 0 means "replaced").

        Returns ``(overall, on masked positions, on unmasked positions, on replaced tokens,
        on non-replaced tokens)`` as Python floats (NaN when a subset is empty).
        """
        hard = torch.round((torch.sign(logits) + 1.0) * 0.5)
        replaced = dis_label.bool()
        # TODO: Fix replace acc
        dis_acc = (hard == dis_label).float().mean().item()
        mask_part_acc = (hard[masked_pos] == 1).float().mean().item()
        other_part_acc = (hard[~masked_pos] == 0).float().mean().item()
        replace_acc = (hard[replaced] == 1).float().mean().item()
        non_replace_acc = (hard[~replaced] == 0).float().mean().item()
        return dis_acc, mask_part_acc, other_part_acc, replace_acc, non_replace_acc

    def forward(self, vit_input, vae_input, bool_masked_pos, d_vae, return_all_tokens=False):
        """Compute generator and discriminator losses for one batch.

        Args:
            vit_input: image batch for the generator; treated as ImageNet-normalized
                (it is un-normalized with IMAGENET_DEFAULT_MEAN/STD for regression targets).
            vae_input: input for the tokenizer ``d_vae``.
            bool_masked_pos: patch mask; presumably (B, h, w), since the ``keep_ratio`` path
                indexes it with three axes. TODO confirm. NOTE: with ``keep_ratio > 0``
                it is modified in place.
            d_vae: discrete tokenizer, used without gradient; the methods called depend on
                ``discrete_vae_type``.
            return_all_tokens: unused; kept for interface compatibility.

        Returns:
            With ``vis_gen``: ``(recon_img, image)`` for visualization. Otherwise the tuple
            ``(gen_loss, dis_loss * dis_loss_weight, mlm_acc, dis_acc, dis_mask_part_acc,
            dis_other_part_acc, dis_replace_acc, dis_non_replace_acc, dis_target_mean,
            dis_target_var)``.

        Raises:
            NotImplementedError: for ``discrete_vae_type == "vit_vqgan"``.
            ValueError: for any other unknown ``discrete_vae_type``.
        """
        grid_h, grid_w = self.patch_embed.patch_shape

        # ---- Tokenize the target image and prepare the masks ----
        with torch.no_grad():
            with torch.cuda.amp.autocast():
                if self.discrete_vae_type in ("dall-e", "zhiliang"):
                    input_ids = d_vae.get_codebook_indices(vae_input).flatten(1)    # B, N
                elif self.discrete_vae_type == "peco":
                    input_ids = d_vae.get_tokens(vae_input)['token']
                elif self.discrete_vae_type == "vqgan":
                    # z_dummy is only needed for its shape when decoding the ids below.
                    z_dummy, _, [_, _, input_ids] = d_vae.encode(vae_input)    # B*N
                    input_ids = input_ids.view(vit_input.shape[0], -1)   # B, N
                elif self.discrete_vae_type == "vit_vqgan":
                    # Tokenization for this tokenizer has not been implemented.
                    raise NotImplementedError(
                        "discrete_vae_type='vit_vqgan' is not supported yet")
                else:
                    raise ValueError(
                        f"Unknown discrete_vae_type: {self.discrete_vae_type!r}")

            if self.keep_ratio > 0.:
                # Un-mask a random keep_ratio fraction of the masked positions. The positions
                # are drawn from the first sample's mask and cleared for the whole batch.
                mask_indices = torch.nonzero(bool_masked_pos[0])
                mask_len = mask_indices.shape[0]
                kept_mask_indices = torch.randperm(mask_len)[:int(mask_len * self.keep_ratio)]
                kept_mask_indices = mask_indices[kept_mask_indices]
                k_h, k_w = kept_mask_indices.t()
                bool_masked_pos[:, k_h, k_w] = 0

            resz_bool_masked_pos = bool_masked_pos.unsqueeze(1)

            if self.disable_interpolate:
                if self.tokenizer_stride == 8:
                    # Upsample the patch mask to the finer (2x) token grid of the tokenizer.
                    resz_bool_masked_pos = F.interpolate(resz_bool_masked_pos.float(), scale_factor=2.0)

            resz_bool_masked_pos = resz_bool_masked_pos.flatten(1).to(torch.bool)
            bool_masked_pos = bool_masked_pos.flatten(1).to(torch.bool)     # B, N

            labels = input_ids[resz_bool_masked_pos]     # num_mask_token

        # ---- Generator: predict token ids at masked positions ----
        if self.disable_gen_amp:
            x = self.forward_features(vit_input, bool_masked_pos)
            x = x[:, 1:]    # B, N, C
        else:
            with torch.cuda.amp.autocast():
                x = self.forward_features(vit_input, bool_masked_pos)
                x = x[:, 1:]    # B, N, C

        gen_logit = self.gen_head(x)

        if self.disable_interpolate:
            # gen_head emits 2x2 sub-token logits per patch; lay them out on the
            # (2 * grid_h) x (2 * grid_w) token grid in raster order.
            gen_logit = rearrange(gen_logit, "b (l1 l2) (p1 p2 v) -> b (l1 p1 l2 p2) v",
                                  l1=grid_h, l2=grid_w, p1=2, p2=2)
            mask_gen_logit = gen_logit[resz_bool_masked_pos]      # num_mask_token, vocab_size
        else:
            mask_gen_logit = gen_logit[bool_masked_pos]      # num_mask_token, vocab_size

        gen_loss = nn.CrossEntropyLoss()(input=mask_gen_logit.float(), target=labels)

        # ---- Replace masked tokens by generator samples and decode to an image ----
        with torch.no_grad():
            # sampling
            sampled_gen_id = gumbel_sample(mask_gen_logit, self.gumbel_sample_temperature)

            if self.random_replace:
                sampled_gen_id = torch.randint_like(sampled_gen_id, 2, self.vocab_size - 2)

            # replace
            gen_input_id = input_ids.clone()    # B, N
            gen_input_id[resz_bool_masked_pos] = sampled_gen_id  # Replace
            dis_label = (input_ids != gen_input_id).float()

            # tok to img decoding
            if self.discrete_vae_type == "vqgan":
                shape = (z_dummy.shape[0], z_dummy.shape[2], z_dummy.shape[3], z_dummy.shape[1])
                gen_input_id = d_vae.quantize.get_codebook_entry(gen_input_id.flatten(0), shape)
                recon_img = d_vae.decode(gen_input_id).float()
            else:
                with torch.cuda.amp.autocast():
                    recon_img = d_vae.decode(gen_input_id).float()

            # Map each tokenizer's decoder output to [0, 1].
            if self.discrete_vae_type == "dall-e":
                recon_img = unmap_pixels(torch.sigmoid(recon_img[:, :3]))
            elif self.discrete_vae_type == "peco" or self.discrete_vae_type == "zhiliang":
                recon_img = recon_img / 255.
            elif self.discrete_vae_type == "vqgan":
                recon_img = torch.clamp(recon_img, -1., 1.)
                recon_img = (recon_img + 1.) / 2.

            if not self.disable_interpolate:
                if self.tokenizer_stride == 8:
                    recon_img = F.interpolate(recon_img, vit_input.shape[-2:], mode='bicubic', align_corners=False)

            if self.vis_gen and not self.reg_diff:
                return recon_img, recon_img

            norm_recon_img = TF.normalize(recon_img, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)

        # ---- Discriminator ----
        with torch.cuda.amp.autocast():
            dis_output, dis_cls_output = self.forward_features_dis(norm_recon_img)

        # Un-normalization constants on the same device as the input (not hard-coded CUDA).
        device = vit_input.device
        mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN, device=device)[None, :, None, None]
        std = torch.as_tensor(IMAGENET_DEFAULT_STD, device=device)[None, :, None, None]

        if self.vis_gen and self.reg_diff:
            return recon_img, self._visualize_regression(dis_output, vit_input, mean, std)

        with torch.no_grad():
            mlm_acc = (mask_gen_logit.max(-1)[1] == labels).float().mean().item()

        if self.reg_diff:
            dis_target = self._regression_target(vit_input, recon_img, mean, std)

            if dis_cls_output is None:
                dis_loss = nn.L1Loss()(dis_output.float(), dis_target) + nn.MSELoss()(dis_output.float(), dis_target)
                dis_acc = 0
                dis_mask_part_acc = 0
                dis_other_part_acc = 0
                dis_replace_acc = 0
                dis_non_replace_acc = 0
            else:
                dis_cls_output = self._patch_logits_to_tokens(dis_cls_output, (grid_h, grid_w))
                dis_loss = 2 * F.binary_cross_entropy_with_logits(dis_cls_output.float(), dis_label) + nn.L1Loss()(dis_output.float(), dis_target) + nn.MSELoss()(dis_output.float(), dis_target)
                with torch.no_grad():
                    (dis_acc, dis_mask_part_acc, dis_other_part_acc,
                     dis_replace_acc, dis_non_replace_acc) = self._replaced_token_metrics(
                        dis_cls_output, dis_label, resz_bool_masked_pos)

            with torch.no_grad():
                dis_target_mean = torch.mean(dis_target)
                dis_target_var = torch.var(dis_target)
        else:
            # One logit per token (or per 2x2 sub-token with disable_interpolate).
            dis_output = self._patch_logits_to_tokens(dis_output, (grid_h, grid_w))
            dis_loss = F.binary_cross_entropy_with_logits(dis_output.float(), dis_label)

            with torch.no_grad():
                (dis_acc, dis_mask_part_acc, dis_other_part_acc,
                 dis_replace_acc, dis_non_replace_acc) = self._replaced_token_metrics(
                    dis_output, dis_label, resz_bool_masked_pos)

            dis_target_mean = 0
            dis_target_var = 0

        return gen_loss, dis_loss * self.dis_loss_weight, mlm_acc, dis_acc, dis_mask_part_acc, dis_other_part_acc, dis_replace_acc, dis_non_replace_acc, dis_target_mean, dis_target_var


@register_model
def beit_e_base_patch16_224_voc8k(pretrained=False, **kwargs):
    """Base-size model: 768-d, 12 blocks, 12 heads, patch 16, FP32-forced LayerNorm.

    Args:
        pretrained: if True, load ``checkpoint["model"]`` from ``init_ckpt``.
        **kwargs: forwarded to ``VisionTransformerForMaskedImageModeling``, except
            ``num_classes`` (discarded if present) and ``init_ckpt`` (consumed here;
            the model constructor does not accept it).

    Raises:
        ValueError: if ``pretrained`` is True but no ``init_ckpt`` was given.
    """
    kwargs.pop("num_classes", None)  # may be passed by the model-creation helper; unused here
    init_ckpt = kwargs.pop("init_ckpt", None)
    model = VisionTransformerForMaskedImageModeling(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(LayerNormWithForceFP32, eps=1e-6), 
        **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        if init_ckpt is None:
            raise ValueError("pretrained=True requires an 'init_ckpt' checkpoint path")
        # torch.load unpickles the file: only load trusted checkpoints.
        state = torch.load(init_ckpt, map_location="cpu")
        model.load_state_dict(state["model"])
    return model


@register_model
def beit_e_small_patch16_224_voc8k(pretrained=False, **kwargs):
    """Small-size model: 384-d, 12 blocks, 6 heads, patch 16, FP32-forced LayerNorm.

    Args:
        pretrained: if True, load ``checkpoint["model"]`` from ``init_ckpt``.
        **kwargs: forwarded to ``VisionTransformerForMaskedImageModeling``, except
            ``num_classes`` (discarded if present) and ``init_ckpt`` (consumed here;
            the model constructor does not accept it).

    Raises:
        ValueError: if ``pretrained`` is True but no ``init_ckpt`` was given.
    """
    kwargs.pop("num_classes", None)  # may be passed by the model-creation helper; unused here
    init_ckpt = kwargs.pop("init_ckpt", None)
    model = VisionTransformerForMaskedImageModeling(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(LayerNormWithForceFP32, eps=1e-6), 
        **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        if init_ckpt is None:
            raise ValueError("pretrained=True requires an 'init_ckpt' checkpoint path")
        # torch.load unpickles the file: only load trusted checkpoints.
        state = torch.load(init_ckpt, map_location="cpu")
        model.load_state_dict(state["model"])
    return model


@register_model
def beit_e_base_patch16_224_voc8k_debug(pretrained=False, **kwargs):
    """Debug copy of the base-size model (768-d, 12 blocks, 12 heads, patch 16).

    Args:
        pretrained: if True, load ``checkpoint["model"]`` from ``init_ckpt``.
        **kwargs: forwarded to ``VisionTransformerForMaskedImageModeling``, except
            ``num_classes`` (discarded if present, which the constructor would reject) and
            ``init_ckpt`` (consumed here; the constructor does not accept it).

    Raises:
        ValueError: if ``pretrained`` is True but no ``init_ckpt`` was given.
    """
    kwargs.pop("num_classes", None)  # may be passed by the model-creation helper; unused here
    init_ckpt = kwargs.pop("init_ckpt", None)
    model = VisionTransformerForMaskedImageModeling(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(LayerNormWithForceFP32, eps=1e-6),
        **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        if init_ckpt is None:
            raise ValueError("pretrained=True requires an 'init_ckpt' checkpoint path")
        # torch.load unpickles the file: only load trusted checkpoints.
        state = torch.load(init_ckpt, map_location="cpu")
        model.load_state_dict(state["model"])
    return model



@register_model
def beit_e_large_patch16_224_8k_vocab(pretrained=False, **kwargs):
    """Large-size model: 1024-d, 24 blocks, 16 heads, patch 16, plain ``nn.LayerNorm``.

    Args:
        pretrained: if True, load ``checkpoint["model"]`` from ``init_ckpt``.
        **kwargs: forwarded to ``VisionTransformerForMaskedImageModeling``, except
            ``num_classes`` (discarded if present) and ``init_ckpt`` (consumed here;
            the model constructor does not accept it).

    Raises:
        ValueError: if ``pretrained`` is True but no ``init_ckpt`` was given.
    """
    kwargs.pop("num_classes", None)  # may be passed by the model-creation helper; unused here
    init_ckpt = kwargs.pop("init_ckpt", None)
    model = VisionTransformerForMaskedImageModeling(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        if init_ckpt is None:
            raise ValueError("pretrained=True requires an 'init_ckpt' checkpoint path")
        # torch.load unpickles the file: only load trusted checkpoints.
        state = torch.load(init_ckpt, map_location="cpu")
        model.load_state_dict(state["model"])
    return model



@register_model
def beit_e_resnet50_patch16_224_voc8k(pretrained=False, **kwargs):
    """ResNet-50 variant built from ``CNNForMaskedImageModeling`` (patch 16, 384-d).

    Args:
        pretrained: if True, load ``checkpoint["model"]`` from ``kwargs["init_ckpt"]``.
        **kwargs: forwarded to ``CNNForMaskedImageModeling``; ``num_classes`` is
            discarded if present.

    Raises:
        ValueError: if ``pretrained`` is True but no ``init_ckpt`` was given.
    """
    kwargs.pop("num_classes", None)  # may be passed by the model-creation helper; unused here
    # NOTE(review): ``init_ckpt`` (if given) stays in kwargs and is therefore also forwarded
    # to CNNForMaskedImageModeling, whose signature is not visible here — confirm it accepts it.
    model = CNNForMaskedImageModeling(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, cnn="resnet50",
        norm_layer=partial(LayerNormWithForceFP32, eps=1e-6),
        **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        init_ckpt = kwargs.get("init_ckpt")
        if init_ckpt is None:
            raise ValueError("pretrained=True requires an 'init_ckpt' checkpoint path")
        # torch.load unpickles the file: only load trusted checkpoints.
        state = torch.load(init_ckpt, map_location="cpu")
        model.load_state_dict(state["model"])
    return model



